{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# Video Demo\n",
    "Here, we provide a demo on how to use our model to process a video and visualize the tracking results.\n",
    "\n",
    "We selected an open hpop dance video from the internet to demonstrate our demo. You can also choose other custom videos. **Please note that it is crucial to select the appropriate trained MOTIP weights and configuration for different tracking scenarios.**\n",
    "\n",
    "We process the video on NVIDIA RTX 3080Ti, achieving a nearly real-time tracking."
   ],
   "id": "25c3e69755e855a7"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### System Environment\n",
    "1. Modify the root path to the project path.\n",
    "2. Make sure you have a cuda device available."
   ],
   "id": "64dadffc91229559"
  },
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-04-06T04:43:01.735691Z",
     "start_time": "2025-04-06T04:43:00.935612Z"
    }
   },
   "source": [
    "import os\n",
    "import sys\n",
    "import torch\n",
    "\n",
    "\n",
    "current_file_path = os.path.abspath(\"\")\n",
    "parent_dir = os.path.dirname(current_file_path)\n",
    "sys.path.append(parent_dir)\n",
    "os.chdir(parent_dir)\n",
    "print(f\"Current root path is set to {parent_dir}\")\n",
    "\n",
    "torch_version = torch.__version__\n",
    "cuda_available = torch.cuda.is_available()\n",
    "\n",
    "if not cuda_available:\n",
    "    raise RuntimeError(\"CUDA is not available\")\n",
    "\n",
    "print(f\"Hello! Welcome to use the video process demo. Your torch version is {torch_version} and CUDA is available.\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current root path is set to /home/gaoruopeng/Code/MOTIP-NG/GitHub\n",
      "Hello! Welcome to use the video process demo. Your torch version is 2.4.0 and CUDA is available.\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Prepare your video (.mp4 for example):",
   "id": "6f513017d6a31533"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-06T04:43:05.010231Z",
     "start_time": "2025-04-06T04:43:05.006827Z"
    }
   },
   "cell_type": "code",
   "source": [
    "os.makedirs(\"./outputs/video_process_demo/\", exist_ok=True)\n",
    "video_path = os.path.join(\"./outputs/video_process_demo/\", f\"hpop_dancers.mp4\")\n",
    "output_path = os.path.join(\"./outputs/video_process_demo/\", f\"hpop_dancers_tracking.mp4\")"
   ],
   "id": "40cbc65feda44f45",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "#### [Optional] Download a video from Bilibili if you don't have a video",
   "id": "770162a8ffb32c70"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-06T04:43:11.939185Z",
     "start_time": "2025-04-06T04:43:08.950120Z"
    }
   },
   "cell_type": "code",
   "source": [
    "video_url = \"https://www.bilibili.com/video/BV19mZ2YzERT/\"\n",
    "video_dir = os.path.join(\"./outputs/video_process_demo/\", f\"hpop_dancers\")\n",
    "\n",
    "os.system(f\"you-get -o {video_dir} {video_url}\")\n",
    "files = os.listdir(video_dir)\n",
    "# Search the .mp4 file, change name to \"hpop_dancers.mp4\", move to outputs/video_process_demo/\n",
    "for file in files:\n",
    "    if file.endswith(\".mp4\"):\n",
    "        os.rename(os.path.join(video_dir, file), video_path)\n",
    "        break"
   ],
   "id": "32c5b97fdba544ef",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001B[33myou-get: You will need login cookies for 720p formats or above. (use --cookies to load cookies.txt.)\u001B[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "site:                Bilibili\n",
      "title:               izna《SIGN》练习室舞蹈(Fix ver.)\n",
      "stream:\n",
      "    - format:        \u001B[7mdash-flv480-AVC\u001B[0m\n",
      "      container:     mp4\n",
      "      quality:       清晰 480P avc1.640033\n",
      "      size:          13.5 MiB (14166376 bytes)\n",
      "    # download-with: \u001B[4myou-get --format=dash-flv480-AVC [URL]\u001B[0m\n",
      "\n",
      "Downloading izna《SIGN》练习室舞蹈(Fix ver.).mp4 ...\n",
      " 100% ( 13.5/ 13.5MB) ├████████████████████████████████████████┤[2/2]  253 MB/s\n",
      "Merging video parts... Merged into izna《SIGN》练习室舞蹈(Fix ver.).mp4\n",
      "\n",
      "Downloading izna《SIGN》练习室舞蹈(Fix ver.).cmt.xml ...\n",
      "\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "#### [Optional] Display the video",
   "id": "a05dc2cb4e8cd581"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "from IPython.display import Video\n",
    "\n",
    "Video(video_path, embed=True)"
   ],
   "id": "d8c48fda8f3c9f65",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Build our model",
   "id": "45dd461e8a64cf98"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-06T04:43:19.445797Z",
     "start_time": "2025-04-06T04:43:17.795812Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from utils.misc import yaml_to_dict\n",
    "from configs.util import load_super_config\n",
    "\n",
    "\n",
    "config_path = \"./configs/r50_deformable_detr_motip_dancetrack.yaml\"\n",
    "checkpoint_path = \"./outputs/r50_deformable_detr_motip_dancetrack/r50_deformable_detr_motip_dancetrack.pth\"\n",
    "config = yaml_to_dict(config_path)\n",
    "config = load_super_config(config, config[\"SUPER_CONFIG_PATH\"])\n",
    "dtype = torch.float16       # torch.float32 or torch.float16, we select float16 for faster inference\n",
    "\n",
    "\n",
    "from models.motip import build as build_model\n",
    "from models.misc import load_checkpoint\n",
    "from models.runtime_tracker import RuntimeTracker\n",
    "model, _ = build_model(config)\n",
    "# Load the model weights\n",
    "load_checkpoint(model, checkpoint_path)\n",
    "model.eval()\n",
    "model = model.cuda()\n",
    "if dtype == torch.float16:\n",
    "    model.half()\n",
    "\n",
    "print(\"Model built successfully.\")"
   ],
   "id": "e177068ce0247419",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/gaoruopeng/anaconda3/envs/MOTIP-NG/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/home/gaoruopeng/anaconda3/envs/MOTIP-NG/lib/python3.12/site-packages/torch/functional.py:513: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /opt/conda/conda-bld/pytorch_1720538455419/work/aten/src/ATen/native/TensorShape.cpp:3609.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model built successfully.\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Process the video",
   "id": "dda99259fbf3a3df"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-06T04:46:38.360921Z",
     "start_time": "2025-04-06T04:43:22.912949Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import cv2\n",
    "from utils.nested_tensor import nested_tensor_from_tensor_list\n",
    "from tqdm import tqdm\n",
    "from demo.colormap import get_color\n",
    "\n",
    "\n",
    "def simple_transform(\n",
    "        image, max_shorter, max_longer, image_dtype,\n",
    "):\n",
    "    from torchvision.transforms import functional as F\n",
    "\n",
    "    image = F.to_tensor(image)\n",
    "    image = F.resize(image, size=max_shorter, max_size=max_longer)\n",
    "    image = F.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "    if image_dtype != torch.float32:\n",
    "        image = image.to(image_dtype)\n",
    "    return image.cuda()\n",
    "\n",
    "\n",
    "video_cap = cv2.VideoCapture(video_path)\n",
    "if not video_cap.isOpened():\n",
    "    raise RuntimeError(f\"Failed to open video file: {video_path}\")\n",
    "# Get video properties\n",
    "fps = video_cap.get(cv2.CAP_PROP_FPS)\n",
    "width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n",
    "height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
    "length = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
    "print(f\"The video {video_path} seems OK. It has {fps} fps, {width} width and {height} height.\")\n",
    "fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n",
    "video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))\n",
    "\n",
    "runtime_tracker = RuntimeTracker(\n",
    "    model=model,\n",
    "    sequence_hw=(height, width),\n",
    "    assignment_protocol=\"object-max\",\n",
    "    miss_tolerance=30,\n",
    "    det_thresh=0.5,\n",
    "    newborn_thresh=0.5,\n",
    "    id_thresh=0.2,\n",
    "    dtype=dtype,\n",
    ")\n",
    "\n",
    "for frame_idx in tqdm(range(length), desc=\"Processing video\", unit=\"frame\"):\n",
    "    ret, frame = video_cap.read()\n",
    "    if not ret:\n",
    "        break\n",
    "\n",
    "    # Convert the frame to a tensor\n",
    "    frame_tensor = simple_transform(frame, max_shorter=800, max_longer=1440, image_dtype=dtype)\n",
    "    frame_tensor = nested_tensor_from_tensor_list([frame_tensor])\n",
    "\n",
    "    # Run the tracker on the frame\n",
    "    runtime_tracker.update(frame_tensor)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        track_results = runtime_tracker.get_track_results()\n",
    "\n",
    "    for bbox, obj_id in zip(track_results[\"bbox\"], track_results[\"id\"]):\n",
    "        x, y, w, h = map(int, bbox)\n",
    "        cv2.rectangle(frame, (x, y), (x + w, y + h), get_color(obj_id, rgb=False, use_int=True), 2)\n",
    "        cv2.putText(frame, f\"ID: {obj_id}\", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, get_color(obj_id, rgb=False, use_int=True), 2)\n",
    "\n",
    "    video_writer.write(frame)\n",
    "\n",
    "    frame_idx += 1\n",
    "\n",
    "video_cap.release()\n",
    "video_writer.release()\n",
    "\n",
    "print(f\"Video processing completed. The output video is saved to {output_path}.\")"
   ],
   "id": "1502b31dd6aedf8e",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The video ./outputs/video_process_demo/hpop_dancers.mp4 seems OK. It has 23.976038875306628 fps, 852 width and 480 height.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing video: 100%|██████████| 4019/4019 [03:15<00:00, 20.57frame/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Video processing completed. The output video is saved to ./outputs/video_process_demo/hpop_dancers_tracking.mp4.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "9c9ecf003296290d"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
